TRN2 Meshes and Configurations #916
Conversation
Added a ModelConfigModifier that overrides the class for a module, allowing different model configurations based on model size and platform.
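As a rough standalone sketch of the idea (the Config dataclass and replace_class_at_path helper below are hypothetical illustrations, not the axlearn API), overriding the class at a dotted module path could look like this:

```python
from dataclasses import dataclass, field

# Hypothetical minimal config tree; the real axlearn ConfigBase is much richer.
@dataclass
class Config:
    klass: str = "BaseLayer"
    children: dict = field(default_factory=dict)

def replace_class_at_path(root: Config, path: str, new_klass: str) -> None:
    """Walks a dotted path like 'model.decoder.layer' and overrides the class there."""
    node = root
    for key in path.split("."):
        node = node.children[key]  # raises KeyError if the path is stale
    node.klass = new_klass

cfg = Config(children={"model": Config(children={"decoder": Config()})})
replace_class_at_path(cfg, "model.decoder", "NeuronFriendlyDecoder")
assert cfg.children["model"].children["decoder"].klass == "NeuronFriendlyDecoder"
```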
Thank you for making this change, overall it looks good. A few nit comments.
continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
Can you try to extract a common util function named something like def replace_module_recursive(target_modules: str, config_key: str, target_config) and apply it to both here and RematSpecModifier?
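For illustration, such a shared helper might return the found config together with its parent and key. The sketch below is a plain-Python illustration using attribute lookup; the FoundModule type and the implementation details are assumptions, not the PR's code:

```python
from typing import Any, NamedTuple

class FoundModule(NamedTuple):
    module: Any         # config found at the dotted path
    parent: Any         # its parent config
    key_in_parent: str  # attribute name of the module on the parent

def find_target_module(module_path: str, root: Any) -> FoundModule:
    """Resolves a dotted path like 'model.decoder.transformer.layer' via attribute lookup."""
    keys = module_path.split(".")
    parent = root
    for key in keys[:-1]:
        parent = getattr(parent, key)
    return FoundModule(getattr(parent, keys[-1]), parent, keys[-1])
```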
I extracted a helper function; let me know if this looks good.
Added
found_module, parent_module, key_in_parent = find_target_module(module_name, cfg)

# Copy configurations from the config being replaced on a best effort basis
Wait, this behavior is not explained in the class comments. So we are not replacing but merging the configs? Maybe we should support a merge function instead?
Yeah, the goal is to change the config to a similar module, which means most of the configuration can be reused from before: essentially replacing the module but merging the config. Let me extract out a merge function.
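A minimal sketch of what such a merge could do (illustrative only; the real PR's merge rules and field handling may differ):

```python
def merge_configs(target_cfg, found_module):
    """Best-effort merge sketch: if the class is unchanged, use the target cfg as-is;
    otherwise keep the target's class and carry over matching attribute values from
    the config being replaced."""
    if getattr(target_cfg, "klass", None) == getattr(found_module, "klass", None):
        return target_cfg
    for field_name, old_value in vars(found_module).items():
        if field_name != "klass" and hasattr(target_cfg, field_name):
            setattr(target_cfg, field_name, old_value)
    return target_cfg
```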
Abstracted out a merge function; let me know if more changes are needed for this.
@ruomingp Thank you for the review. I have addressed all your comments; please let me know if more changes are needed.
for module_name, model_cfg in self._model_cfg_modifications.items():
    found_module = _find_target_module(module_name, cfg)
In utils.py we have get_recursively and set_recursively for Nested[...]. I wonder if it will be useful to add corresponding methods to ConfigBase. Then we can do something like:
for cfg_path, cfg_modification in self._model_cfg_modifications.items():
    child_cfg = cfg.get_recursively(cfg_path)
    child_cfg = cfg_modification(child_cfg, path=cfg_path)
    cfg.set_recursively(cfg_path, value=child_cfg)
Added get_recursively and set_recursively functions to ConfigBase. Let me know if it looks good
I wonder if an alternative (which aims to simplify the ConfigBase api) is to do something similar to Python's sorted; we allow utils.get_recursively to take a value_fn:

# Default behavior is to use key lookup:
utils.get_recursively(..., value_fn=lambda k, v: v[k])
# Custom behavior can be attribute lookup:
utils.get_recursively(..., value_fn=lambda k, v: getattr(v, k))

A benefit is that other non-config instances can also leverage get_recursively.
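A minimal standalone sketch of that value_fn idea (the signature below is illustrative, not the actual utils.get_recursively):

```python
from typing import Any, Callable, Sequence

def get_recursively(
    x: Any,
    path: Sequence[str],
    *,
    value_fn: Callable[[str, Any], Any] = lambda k, v: v[k],
) -> Any:
    """Walks `path` one key at a time, using value_fn for each lookup
    (defaults to dict-style indexing)."""
    for key in path:
        x = value_fn(key, x)
    return x

# Default key lookup:
assert get_recursively({"a": {"b": 1}}, ["a", "b"]) == 1
# Attribute lookup for config-like objects (hypothetical usage):
# get_recursively(cfg, ["decoder", "transformer"], value_fn=lambda k, v: getattr(v, k))
```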
Added a more flexible
Mostly lgtm, some minor comments.
@ruomingp and @kelvin-zou thank you both for the review. I addressed all comments; please let me know if any more changes are needed. The PR looks clean now.
key: str

def recursive_traverse(self, key_path: Sequence[str]) -> tuple[Any, str]:
    """Recursively traverse the config to find the target key.
Please see other comment re get_recursively; also, I wonder whether we actually need recursion here (seems like a loop would be simpler).
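For illustration, a loop-based traversal (hypothetical names; not the PR's code) could look like:

```python
from typing import Any, Sequence

def traverse(cfg: Any, key_path: Sequence[str]) -> tuple[Any, str]:
    """Iteratively walks key_path and returns (parent, final_key), so callers can
    either getattr or setattr on the parent."""
    if not key_path:
        raise ValueError("key_path cannot be empty.")
    parent = cfg
    for key in key_path[:-1]:
        if not hasattr(parent, key):
            raise ValueError(f"{key} is not found in {parent}.")
        parent = getattr(parent, key)
    if not hasattr(parent, key_path[-1]):
        raise ValueError(f"{key_path[-1]} is not found in {parent}.")
    return parent, key_path[-1]
```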
@@ -146,6 +137,110 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        return cfg


class ModelConfigModifier(ConfigModifier):
Which part of this class is specific to model? It seems to take generic modifications?
"""Configure ModelConfigModifier. | ||
|
||
Attributes: | ||
model_cfg_modifications: A mapping from module path |
Outdated?
"""Merge configurations from the config being replaced on a best effort basis. | ||
|
||
Merge Rules: | ||
- Klass is not changed, use target cfg |
Suggested change (add the trailing period):
    - Klass is not changed, use target cfg.
Please end all sentences with punctuation.
target_cfg: configuration that will replace found_module.
found_module: existing configuration whose class will be replaced
Suggested change (capitalize the descriptions):
    target_cfg: Configuration that will replace found_module.
    found_module: Existing configuration whose class will be replaced
if version != Version.V1:
    trn2_model_modifications.append(
        ModelConfigModifier.default_config().set(
            target_config="model.decoder.transformer.layer.self_attention.attention."
A downside of representing these deeply nested configs as string paths is that they are brittle, and can quickly become outdated.
Have we considered using cfg.visit to achieve some of these modifications? E.g., axlearn/axlearn/common/layers.py, line 266 in 1c22688:

def set_layer_norm_eps_recursively(cfg: ConfigBase, eps: float, set_only_if_none: bool = False):

(A bit late to review, so apologies if this discussion has already taken place.)
# The key string.
key: str

def recursive_traverse(self, key_path: Sequence[str]) -> tuple[Any, str]:
Does this need to be a public method?
traverse_result = self.recursive_traverse(key_path)
return getattr(traverse_result.parent, traverse_result.key)

def set_recursively(self, key_path: Sequence[str], new_value: Any):
Please name consistently with axlearn/axlearn/common/utils.py, lines 907 to 910 in a854738:

def set_recursively(
    x: NestedTensor,
    *,
    value: Tensor,

Suggested change:
    def set_recursively(self, path: Sequence[str], *, value: Any):
value = getattr(self, target_key)
return value.recursive_traverse(key_path[1:])

def get_recursively(self, key_path: Sequence[str]) -> Any:
Suggested change:
    def get_recursively(self, path: Sequence[str]) -> Any:
"""Recursively find the target key in the config and return its value. | ||
|
||
Args: | ||
key_path: A sequence of keys for indexing to get the target value. |
Can path be empty? Maybe it can return self if path is empty?
"""Recursively find the target key in the config and set its value. | ||
|
||
Args: | ||
key_path: A sequence of keys for indexing to set the target value. |
Can path be empty?
Raises:
    ValueError: A key in key_path is not found.
"""
traverse_result = self.recursive_traverse(key_path)
Can we do something like:

if not path:
    raise ValueError(...)
parent = self.get_recursively(path[:-1])
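Filled out, that shape could look roughly like this (a sketch, not the final code in the PR; it assumes get_recursively returns self for an empty path):

```python
def set_recursively(self, path: Sequence[str], *, value: Any):
    """Sets `value` at `path`, resolving the parent config with get_recursively."""
    if not path:
        raise ValueError("path cannot be empty.")
    parent = self.get_recursively(path[:-1])
    setattr(parent, path[-1], value)
```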
This PR adds meshes for TRN2/1 for Fuji models, and a transformer layer configuration favorable to Neuron.
Neuron supports the stacked transformer, and GroupedQKVLinear instead of FusedGroupedQKVLinear for Grouped Query Attention (GQA).
This is a newer version of PR #885; it resolves all comments and requested changes from the linked PR.
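As a hedged illustration only: with the modifier from this PR, selecting the Neuron-friendly attention projection might be written along these lines. The exact dotted path, the modification keyword, and the import locations below are assumptions, not necessarily the PR's final spelling:

```python
# Illustrative sketch; module paths, keyword names, and imports are assumptions.
# from axlearn.common.attention import GroupedQKVLinear
# from axlearn.common.trainer_config_modifier import ModelConfigModifier

trn2_model_modifications = [
    ModelConfigModifier.default_config().set(
        # Swap FusedGroupedQKVLinear for GroupedQKVLinear for GQA on Neuron.
        target_config="model.decoder.transformer.layer.self_attention.attention.input_linear",
        modification=GroupedQKVLinear.default_config(),
    ),
]
```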